import torch
import torch.nn as nn
from torch import optim
from eval import eval_net
from model.segnet import SegNet
from utils.dataset import BasicDataset
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter


# ---------------------------------------------------------------------------
# Training configuration (module-level; read by train() and the entry point).
# ---------------------------------------------------------------------------

# Directories handed to BasicDataset for the training images and their masks.
dir_img = 'data/train/images'
dir_mask = 'data/train/masks'

# Optimisation hyper-parameters.
lr = 1e-3
epochs = 500
batch_size = 32

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# TensorBoard writer shared by the training loop; events go under 'log'.
writer = SummaryWriter('log')

def train(net):
    """Train ``net`` for ``epochs`` epochs on the data in ``dir_img`` / ``dir_mask``.

    Every epoch runs one pass over the training loader with Adam and
    ``BCEWithLogitsLoss``, then calls ``eval_net`` on the validation loader
    and logs two TensorBoard scalars through the module-level ``writer``:
    ``'LOSS'`` (mean training loss of that epoch) and ``'ACC'`` (the score
    returned by ``eval_net``). Both are indexed by the cumulative number of
    optimisation steps. Every 10th epoch (starting with epoch 0) the weights
    are written to ``weight/net_segnet.pth`` and the whole module is written
    to ``save/net_segnet_<epoch>.pth``.

    Args:
        net: The model to optimise in place. It must already live on
            ``device`` and is called as ``net(imgs)``.

    Raises:
        ValueError: If the training loader yields no batches (with
            ``drop_last=True`` this happens when the dataset holds fewer than
            ``batch_size`` samples).
    """
    import os  # local import: only needed to create the checkpoint folders

    dataset = BasicDataset(dir_img, dir_mask)
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                              num_workers=0, pin_memory=True, drop_last=True)
    # NOTE(review): validation iterates the very same dataset as training, so
    # the reported score is not a held-out metric - confirm this is intended.
    val_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                            num_workers=0, pin_memory=True, drop_last=True)

    # Fail early with a clear message instead of a ZeroDivisionError after
    # the first (empty) epoch.
    if len(train_loader) == 0:
        raise ValueError(
            'training loader is empty: dataset has %d samples but batch_size is %d '
            'with drop_last=True' % (len(dataset), batch_size)
        )

    # optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' , patience=2)
    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-8)
    criterion = nn.BCEWithLogitsLoss()

    # torch.save does not create missing parent directories.
    os.makedirs('weight', exist_ok=True)
    os.makedirs('save', exist_ok=True)

    cur_itrs = 0  # optimisation steps since training started (TensorBoard x-axis)
    for epoch in range(epochs):
        net.train()
        # Reset per epoch so 'LOSS' is this epoch's mean, not a running
        # average over the whole run.
        epoch_loss = 0.0
        epoch_batches = 0
        for images, mask in train_loader:
            imgs = images.to(device=device, dtype=torch.float32)
            # BCEWithLogitsLoss needs float targets.
            true_masks = mask.to(device=device, dtype=torch.float32)

            masks_pred = net(imgs)
            loss = criterion(masks_pred, true_masks)
            batch_loss = loss.item()
            epoch_loss += batch_loss
            print("Epoch", epoch, batch_loss)

            optimizer.zero_grad()
            loss.backward()
            # Clamp each gradient element to [-0.1, 0.1] before the update.
            nn.utils.clip_grad_value_(net.parameters(), 0.1)
            optimizer.step()

            epoch_batches += 1
            cur_itrs += 1

        mean_loss = epoch_loss / epoch_batches

        val_score = eval_net(net, val_loader, device)
        # scheduler.step(val_score)
        print('ACC:', val_score)

        writer.add_scalar('LOSS', mean_loss, global_step=cur_itrs)
        writer.add_scalar('ACC', val_score, global_step=cur_itrs)
        if epoch % 10 == 0:
            # Latest weights (overwritten each time) plus a per-epoch snapshot
            # of the full module.
            torch.save(net.state_dict(), 'weight/net_segnet.pth')
            torch.save(net, 'save/net_segnet_%s.pth' % epoch)


if __name__ == '__main__':
    # Script entry point: build the network, resume from the last saved
    # weights when they exist, then start training.
    import os

    weight_path = 'weight/net_segnet.pth'

    # NOTE(review): arguments are presumably (input channels, output classes)
    # - confirm against the SegNet constructor.
    net = SegNet(3, 1)
    # net = nn.DataParallel(net)
    net.to(device=device)

    if os.path.isfile(weight_path):
        # strict=False tolerates missing or unexpected keys in the checkpoint.
        state_dict = torch.load(weight_path, map_location=device)
        net.load_state_dict(state_dict, strict=False)
    else:
        # A first run has no checkpoint yet; train from the initial weights
        # instead of crashing in torch.load.
        print('No checkpoint found at %s, training from scratch' % weight_path)

    train(net)

